iT邦幫忙

2022 iThome 鐵人賽

DAY 9
0
AI & Data

菜鳥工程師第一個電腦視覺(CV)專案-農作物影像辨識競賽系列 第 9

D9-資料集前處理4th:資料集切分&DataLoader

  • 分享至 

  • xImage
  •  

Part1:今日目標

1.Day8(D8) 程式除錯(Debug)、資料集切分、Dataloader。


Part2:內容

1.D8程式除錯(Debug)、資料集切分、Dataloader

(1)Debug D8程式碼: (D8程式碼已更新)
針對 Class CustomImageDataset 的 def __getitem__(self, idx)做以下修改:

  • resize = transforms.Resize([self.width, self.height]): 圖片要重新設定大小,才能讓batch裡面的圖片大小都一樣,不然會報錯(ERROR: stack expects each tensor to be equal size, but got [3, 129, 190] at entry 0 and [3, 190, 190] at entry 1)。
  • label = self.img_labels.iloc[idx, 0]: 要指定0才能取得標籤字串。

(2)資料集切分&DataLoader
將整個Dataset切成三份資料,分別為訓練集(train)、驗證集(val)和測試集(test),各份資料比例為train:val:test=80%:10%:10%。

Step1: 將資料打亂洗牌(Shuffle)後,建立索引(indices)

batch_size = 10 # 一個batch有10張圖片
val_size, test_size = 0.1, 0.1  # train:val:test=0.8:0.1:0.1
shuffle_dataset = True
random_seed= 42

# Create data indices for train & validatin spilt
# Set seed and shuffle, then spilt data to train, validation, test
dataset_size = len(dataset)
indices = list(range(dataset_size))
val_spilt = int(np.floor(val_size * dataset_size))
test_spilt = val_spilt + int(np.floor(test_size * dataset_size))

if shuffle_dataset:
    np.random.seed(random_seed)
    np.random.shuffle(indices)
    
train_indices, val_indices, test_indices = indices[test_spilt:], indices[:val_spilt], indices[val_spilt:test_spilt]

Step2: 透過索引和SubsetRandomSampler()建立sampler,並利用DataLoader載入和操作資料。

# Creating PT data samplers and loaders:
train_sampler = SubsetRandomSampler(train_indices)
val_sampler = SubsetRandomSampler(val_indices)
test_sampler = SubsetRandomSampler(test_indices)


train_loader = torch.utils.data.DataLoader(dataset, batch_size=batch_size, 
                                           sampler=train_sampler)
val_loader = torch.utils.data.DataLoader(dataset, batch_size=batch_size,
                                                sampler=val_sampler)
test_loader = torch.utils.data.DataLoader(dataset, batch_size=batch_size, 
                                           sampler=test_sampler)



Step3: 嘗試操作確認打包後的data_loader,看是否有報錯

# Usage ex:
num_epochs = 2  # 抓2個epoch出來看
for epoch in range(num_epochs):
    print(epoch)
    try:
        for batch_index, (img, labels) in enumerate(test_loader):
            print(batch_index, (img, labels))
    except Exception as ee:
        print(ee)
# 目前沒有報錯

一個epoch會跑過197個batch(因為test_loader共有197個batch),每個batch有10張圖片。

# Debug:
ERROR: stack expects each tensor to be equal size, but got [3, 129, 190] at entry 0 and [3, 190, 190] at entry 1.
    * Reason: images are different size.
    * Sol: transforms.Resize((img_size, img_size))

2.今日新學習到的項目:

(1) torchvision.transforms

  • 對圖片資料做各種操作: 旋轉、評儀、影像縮放等等。

(2) torchvision.transforms.Resize([h, w]):

  • 調整圖片尺寸(Resize the input image to the given size)。
  • 類似用法: resize(img, size[, interpolation, max_size, …])。
  • 神經網路訓練時,建議所有圖片的大小要一致(It is preferable to train and serve a model with the same input types.)。

(3) torch.utils.data.SubsetRandomSampler(indices, generator=None)

  • 不放回的按照給定索引標籤進行資料選取(Samples elements randomly from a given list of indices, without replacement.)。

參考:

Part3:專案進度

  • 切分資料集,並且使用Dataloader載入資料。

Part4:下一步

  • 設定影像分類模型:CNN。
心得小語:
身體疲勞心又有點累的一周,終於過了呼~~~
得趁周末好好休生養息了(好老人言XDD
今日工時: 50mins*3

/images/emoticon/emoticon34.gif

繼續努力吧,你需要的一切將會在一個完美的時機到來
Keep going. Everything you need will come to you at the perfect time.


上一篇
D8-資料集前處理3rd:建立Dataset
下一篇
D10-卷積神經網路CNN_理論學習1st
系列文
菜鳥工程師第一個電腦視覺(CV)專案-農作物影像辨識競賽32
圖片
  直播研討會
圖片
{{ item.channelVendor }} {{ item.webinarstarted }} |
{{ formatDate(item.duration) }}
直播中

尚未有邦友留言

立即登入留言